
% -------------------------------------------------------------------------
% Setup: helper path, raw data, and the names of every regression control.
% -------------------------------------------------------------------------
cd('.')
addpath('matlab functions')

% Student-level records for grades 4-5.
D0 = readtable('NC Data/SCC_4to5.csv');

% Individual student characteristics: prior scores, their squares, demographics.
student_controls = {'lag_mathscore', 'lag_mathscore_sq', ...
                    'lag_readingscore', 'lag_readingscore_sq', ...
                    'female', 'race_black', 'race_hisp', 'race_asian', ...
                    'race_other', 'poverty', 'lep'};
teacher_controls = {'tch_exp'};

% Grade dummies {'schgrd4','schgrd5'}; only entries 2:end enter full_controls.
grade_controls = arrayfun(@(g) sprintf('schgrd%d', g), 4:5, 'UniformOutput', false);

% Group-level versions of the student controls plus n and n_sq, prefixed by
% level (teacher_, school_, district_).
peer_vars = [student_controls, {'n', 'n_sq'}];
class_controls    = cellfun(@(v) ['teacher_'  v], peer_vars, 'UniformOutput', false);
school_controls   = cellfun(@(v) ['school_'   v], peer_vars, 'UniformOutput', false);
district_controls = cellfun(@(v) ['district_' v], peer_vars, 'UniformOutput', false);

% Year dummies yr08 ... yr14; only entries 2:end enter full_controls.
year_controls = arrayfun(@(y) sprintf('yr%02d', y), 8:14, 'UniformOutput', false);

% Pairwise interactions. ORDER MATTERS: later code selects elements by position.
%   1-4   lag_mathscoreX{female,race_black,poverty,lep}
%   5-8   lag_readingscoreX{female,race_black,poverty,lep}
%   9-11  femaleX{race_black,poverty,lep}
%   12-13 race_blackX{poverty,lep}
%   14    povertyXlep
demog = {'female', 'race_black', 'poverty', 'lep'};
interaction_controls = [strcat('lag_mathscoreX', demog), ...
                        strcat('lag_readingscoreX', demog), ...
                        strcat('femaleX', demog(2:end)), ...
                        strcat('race_blackX', demog(3:end)), ...
                        strcat('povertyX', demog(end))];

% Full control list; the first grade and first year dummy are omitted.
full_controls = [{'intercept'}, student_controls, teacher_controls, ...
                 grade_controls(2:end), class_controls, school_controls, ...
                 district_controls, year_controls(2:end), interaction_controls];

% Keep only what the estimation loop needs.
clear student_controls teacher_controls grade_controls class_controls school_controls district_controls peer_vars demog

% Fit and save one value-added model per (subject, model type) pair.
% For each pair: fitfsmod on D -> genvadata -> fitvadist -> save to
% output/main/vam_<subject>_<model>.mat.
% Relies on D0, full_controls, year_controls and interaction_controls
% from the setup section above.

% save() below fails if the target folder is missing, so create it up front.
if exist('output/main', 'dir') ~= 7
    mkdir('output/main');
end

for sbj = {'math','reading'}
    subject = sbj{1};

    % Spec fields shared by both model types for this subject.
    base_specs = struct('outcome', [subject 'score'], ...
                        'controls', {full_controls});

    % Estimation sample: rows whose outcome column is not NaN.
    D = D0(~isnan(D0.(base_specs.outcome)), :);

    for mdl = {'NOmatching','matching'}
        model = mdl{1};

        % Start each model from the base spec, so no field set for the
        % previous model (e.g. tchcoef) is visible to fitfsmod for this one.
        modspecs = base_specs;

        switch model
            case 'NOmatching'
                % Year dummies only.
                tchcoef = year_controls;
                numcomp = 1;
            case 'matching'
                % Year dummies, prior scores, demographics and the first
                % 10 interaction terms.
                % NOTE(review): interaction_controls(1:10) is the 8 lag-score
                % interactions plus femaleXrace_black and femaleXpoverty
                % (femaleXlep is not included) -- confirm this is intended.
                tchcoef = [year_controls ...
                    {'lag_mathscore' 'lag_mathscore_sq' 'lag_readingscore' 'lag_readingscore_sq'} ...
                    {'female','race_black','poverty','lep'} ...
                    interaction_controls(1:10)];
                % Drop every term built from the other subject's score.
                other = setdiff({'math','reading'}, subject);
                tchcoef(contains(tchcoef, strcat(other, 'score'))) = [];
                numcomp = 2;
            otherwise
                error('Unknown model type ''%s''.', model);
        end

        % fitfsmod gets the intercept plus the non-year tchcoef terms.
        modspecs.fsfevars = [{'intercept'} tchcoef(~startsWith(tchcoef,'yr'))];
        fsest = fitfsmod(D, modspecs);

        modspecs.tchcoef = tchcoef;
        vadata = genvadata(D, modspecs, fsest.controls);
        % Reseed before each fit for reproducibility
        % (presumably fitvadist draws random numbers -- confirm).
        rng(0)
        vamdl = fitvadist(vadata, numcomp, {'Display','iter','FunctionTolerance',1e-8});
        vamdl.tchcoef = modspecs.tchcoef;
        vamdl.controls = fsest.controls;
        vamdl.modspecs = modspecs;

        % '-struct' stores each field of vamdl as its own variable in the file.
        save(sprintf('output/main/vam_%s_%s.mat', subject, model), '-struct', 'vamdl')
    end
end

        



